{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型信息： GPT2LMHeadModel(\n",
      "  (transformer): GPT2Model(\n",
      "    (wte): Embedding(50257, 768)\n",
      "    (wpe): Embedding(1024, 768)\n",
      "    (drop): Dropout(p=0.1, inplace=False)\n",
      "    (h): ModuleList(\n",
      "      (0): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (1): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (2): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (3): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (4): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (5): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (6): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (7): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (8): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (9): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (10): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "      (11): GPT2Block(\n",
      "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): GPT2Attention(\n",
      "          (c_attn): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): GPT2MLP(\n",
      "          (c_fc): Conv1D()\n",
      "          (c_proj): Conv1D()\n",
      "          (act): NewGELUActivation()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
      "  )\n",
      "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
      ")\n",
      "分词器信息： PreTrainedTokenizer(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True)})\n",
      "词汇表大小: 50257\n",
      "部分词汇示例: ['parent', 'Art', 'pack', 'Ġdiplom', 'rets']\n"
     ]
    }
   ],
   "source": [
    "import torch # 导入torch\n",
    "from transformers import GPT2Tokenizer # 导入GPT2分词器\n",
    "from transformers import GPT2LMHeadModel # 导入GPT2语言模型\n",
    "# Checkpoint to load; larger variants such as \"gpt2-medium\" or \"gpt2-large\" also work.\n",
    "model_name = \"gpt2\"\n",
    "# Run on the GPU when torch can see one, otherwise on the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Load the pretrained tokenizer and language model, then move the model to the device.\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
    "model = GPT2LMHeadModel.from_pretrained(model_name)\n",
    "model = model.to(device)\n",
    "# Vocabulary of the tokenizer, as returned by get_vocab().\n",
    "vocab = tokenizer.get_vocab()\n",
    "\n",
    "print(\"模型信息：\", model)\n",
    "print(\"分词器信息：\", tokenizer)\n",
    "print(\"词汇表大小:\", len(vocab))\n",
    "print(\"部分词汇示例:\", list(vocab)[8000:8005])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example 1:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-08 22:44:09.419588: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: hi, how are you?<|endoftext|>\n",
      "Target: i am doing well, thank you. how about you?<|endoftext|>\n",
      "Example 2:\n",
      "Input: i am good, thanks for asking. what can you do?<|endoftext|>\n",
      "Target: i am an ai language model. i can help you answer questions.<|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import Dataset  # 导入Pytorch的Dataset\n",
    "# Custom dataset of (user prompt, AI reply) token-id pairs read from a chat transcript.\n",
    "class ChatDataset(Dataset):\n",
    "    \"\"\"Dataset of tokenized (user prompt, AI reply) pairs.\n",
    "\n",
    "    Transcript lines starting with \"User:\" are prompts and lines starting with\n",
    "    \"AI:\" are replies; every other line is ignored. Each kept line is tokenized\n",
    "    and terminated with the tokenizer's EOS id.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path of the UTF-8 transcript file.\n",
    "        tokenizer: Callable returning a mapping with an \"input_ids\" list and\n",
    "            exposing ``eos_token_id``.\n",
    "        vocab: Vocabulary mapping; stored on the instance, not used by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    USER_PREFIX = \"User:\"\n",
    "    AI_PREFIX = \"AI:\"\n",
    "\n",
    "    def __init__(self, file_path, tokenizer, vocab):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.vocab = vocab\n",
    "        self.input_data, self.target_data = self.load_and_process_data(file_path)\n",
    "\n",
    "    def _encode(self, text):\n",
    "        \"\"\"Tokenize ``text`` and append EOS, returning a 1-D long tensor.\"\"\"\n",
    "        # Use the tokenizer owned by this dataset, not a notebook-level global.\n",
    "        token_ids = list(self.tokenizer(text)[\"input_ids\"])\n",
    "        token_ids.append(self.tokenizer.eos_token_id)\n",
    "        return torch.tensor(token_ids, dtype=torch.long)\n",
    "\n",
    "    def load_and_process_data(self, file_path):\n",
    "        \"\"\"Read the transcript and return ``(input_data, target_data)``.\n",
    "\n",
    "        A prompt is stored only once its reply has been read, so both lists\n",
    "        always have the same length and index ``i`` of one belongs to index\n",
    "        ``i`` of the other. A prompt without a reply, or a reply without a\n",
    "        preceding prompt, is skipped.\n",
    "        \"\"\"\n",
    "        # Explicit encoding so the result does not depend on the platform default.\n",
    "        with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "            lines = f.readlines()\n",
    "        input_data, target_data = [], []\n",
    "        pending_prompt = None\n",
    "        for line in lines:\n",
    "            if line.startswith(self.USER_PREFIX):\n",
    "                # Cut exactly the prefix, then trim the separating space and newline.\n",
    "                pending_prompt = self._encode(line[len(self.USER_PREFIX):].strip())\n",
    "            elif line.startswith(self.AI_PREFIX) and pending_prompt is not None:\n",
    "                input_data.append(pending_prompt)\n",
    "                target_data.append(self._encode(line[len(self.AI_PREFIX):].strip()))\n",
    "                pending_prompt = None\n",
    "        return input_data, target_data\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of (prompt, reply) pairs.\"\"\"\n",
    "        return len(self.input_data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the ``(prompt, reply)`` tensors stored at ``idx``.\"\"\"\n",
    "        return self.input_data[idx], self.target_data[idx]\n",
    "\n",
    "# Build the dataset from the chat transcript, handing over the tokenizer and vocabulary.\n",
    "file_path = \"chat.txt\"\n",
    "chat_dataset = ChatDataset(file_path, tokenizer, vocab)\n",
    "# Preview the first two pairs, decoded back to text.\n",
    "for i in (0, 1):\n",
    "    input_example, target_example = chat_dataset[i]\n",
    "    print(f\"Example {i + 1}:\")\n",
    "    print(\"Input:\", tokenizer.decode(input_example))\n",
    "    print(\"Target:\", tokenizer.decode(target_example))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input batch tensor size: torch.Size([2, 21])\n",
      "Target batch tensor size: torch.Size([2, 21])\n",
      "Input batch tensor:\n",
      "tensor([[ 5303,   837,   703,   389,   345,  5633, 50256, 50256, 50256, 50256,\n",
      "         50256, 50256, 50256, 50256, 50256, 50256],\n",
      "        [   72,   716,   922,   837,  5176,   329,  4737,   764,   644,   460,\n",
      "           345,   466,  5633, 50256, 50256, 50256]])\n",
      "Target batch tensor:\n",
      "tensor([[   72,   716,  1804,   880,   837,  5875,   345,   764,   703,   546,\n",
      "           345,  5633, 50256, 50256, 50256, 50256],\n",
      "        [   72,   716,   281,   257,    72,  3303,  2746,   764,  1312,   460,\n",
      "          1037,   345,  3280,  2683,   764, 50256]])\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader # 导入Dataloader\n",
    "# GPT-2 ships without a padding token, and \"<pad>\" is not in its vocabulary, so\n",
    "# looking that string up falls back to the unknown-token id, which for this\n",
    "# tokenizer is the same id as <|endoftext|>. Reuse EOS as the padding token\n",
    "# explicitly instead of reaching that state through a non-existent token.\n",
    "# NOTE(review): with pad == EOS, a loss that ignores pad_token_id also ignores\n",
    "# the real end-of-sequence label of every target.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "# Pad (or truncate) a batch of 1-D token-id sequences to one common length.\n",
    "def pad_sequence(sequences, padding_value=0, length=None):\n",
    "    \"\"\"Stack variable-length sequences into one ``(batch, length)`` long tensor.\n",
    "\n",
    "    Args:\n",
    "        sequences: Iterable of 1-D tensors or lists of token ids.\n",
    "        padding_value: Value written to the positions past each sequence's end.\n",
    "        length: Width of the result. Defaults to the longest sequence; a smaller\n",
    "            value truncates the sequences that do not fit.\n",
    "\n",
    "    Returns:\n",
    "        ``torch.long`` tensor of shape ``(len(sequences), length)``.\n",
    "    \"\"\"\n",
    "    sequences = list(sequences)\n",
    "    if length is None:\n",
    "        # default=0 keeps an empty batch from raising inside max().\n",
    "        max_length = max((len(seq) for seq in sequences), default=0)\n",
    "    else:\n",
    "        max_length = length\n",
    "    # Start from a tensor filled with the padding value, then overwrite each row's prefix.\n",
    "    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)\n",
    "    for i, seq in enumerate(sequences):\n",
    "        # Clamp so a sequence longer than the requested width is truncated\n",
    "        # instead of failing with a shape mismatch.\n",
    "        end = min(len(seq), max_length)\n",
    "        result[i, :end] = torch.as_tensor(seq[:end], dtype=torch.long)\n",
    "    return result\n",
    "\n",
    "# Collate function: turns a list of (source, target) pairs into two padded tensors.\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Pad the sources and targets of ``batch`` to one shared width.\n",
    "\n",
    "    The width is the length of the longest sequence found among both the\n",
    "    sources and the targets; padding uses ``tokenizer.pad_token_id``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(sources, targets)`` of tensors produced by ``pad_sequence``.\n",
    "    \"\"\"\n",
    "    sources, targets = zip(*batch)\n",
    "    longest = max(len(seq) for seq in (*sources, *targets))\n",
    "    pad_id = tokenizer.pad_token_id\n",
    "    padded_sources = pad_sequence(sources, padding_value=pad_id, length=longest)\n",
    "    padded_targets = pad_sequence(targets, padding_value=pad_id, length=longest)\n",
    "    return padded_sources, padded_targets\n",
    "\n",
    "# Batch the dataset; collate_fn pads every batch to a common width.\n",
    "chat_dataloader = DataLoader(chat_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)\n",
    "\n",
    "# Inspect a single batch. Sizes and contents are printed from the same batch:\n",
    "# with shuffle=True each new pass over the loader may start with a different\n",
    "# batch, so drawing twice would pair the printed sizes with unrelated tensors.\n",
    "for input_batch, target_batch in chat_dataloader:\n",
    "    print(\"Input batch tensor size:\", input_batch.size())\n",
    "    print(\"Target batch tensor size:\", target_batch.size())\n",
    "    print(\"Input batch tensor:\")\n",
    "    print(input_batch)\n",
    "    print(\"Target batch tensor:\")\n",
    "    print(target_batch)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0020, Batch: 0003, cost = 0.413135\n",
      "Epoch: 0040, Batch: 0003, cost = 0.237347\n",
      "Epoch: 0060, Batch: 0003, cost = 1.088322\n",
      "Epoch: 0080, Batch: 0003, cost = 0.048271\n",
      "Epoch: 0100, Batch: 0003, cost = 0.001409\n",
      "Epoch: 0120, Batch: 0003, cost = 0.000533\n",
      "Epoch: 0140, Batch: 0003, cost = 0.001399\n",
      "Epoch: 0160, Batch: 0003, cost = 0.001586\n",
      "Epoch: 0180, Batch: 0003, cost = 0.000865\n",
      "Epoch: 0200, Batch: 0003, cost = 0.000385\n",
      "Epoch: 0220, Batch: 0003, cost = 0.000295\n",
      "Epoch: 0240, Batch: 0003, cost = 0.000166\n",
      "Epoch: 0260, Batch: 0003, cost = 0.000278\n",
      "Epoch: 0280, Batch: 0003, cost = 0.000117\n",
      "Epoch: 0300, Batch: 0003, cost = 0.000136\n",
      "Epoch: 0320, Batch: 0003, cost = 0.000085\n",
      "Epoch: 0340, Batch: 0003, cost = 0.000135\n",
      "Epoch: 0360, Batch: 0003, cost = 0.000020\n",
      "Epoch: 0380, Batch: 0003, cost = 0.000102\n",
      "Epoch: 0400, Batch: 0003, cost = 0.000052\n",
      "Epoch: 0420, Batch: 0003, cost = 0.000047\n",
      "Epoch: 0440, Batch: 0003, cost = 0.000059\n",
      "Epoch: 0460, Batch: 0003, cost = 0.000013\n",
      "Epoch: 0480, Batch: 0003, cost = 0.000012\n",
      "Epoch: 0500, Batch: 0003, cost = 0.000056\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "# Training hyper-parameters.\n",
    "NUM_EPOCHS = 500\n",
    "LOG_EVERY = 20  # print the loss every LOG_EVERY epochs\n",
    "LEARNING_RATE = 1e-4\n",
    "# Label value skipped by the loss; -100 is CrossEntropyLoss's default ignore_index.\n",
    "IGNORE_INDEX = -100\n",
    "\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "# from_pretrained() hands the model back in evaluation mode (dropout off), and\n",
    "# generate_text_beam_search() calls model.eval() as well, so switch to\n",
    "# training mode explicitly before optimizing.\n",
    "model.train()\n",
    "# NOTE(review): the logits at position i of the user prompt are trained against\n",
    "# token i of the reply, which is not the usual shifted next-token objective --\n",
    "# confirm this is intended.\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    for batch_idx, (input_batch, target_batch) in enumerate(chat_dataloader):\n",
    "        optimizer.zero_grad()  # reset gradients accumulated by the previous step\n",
    "        input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
    "        # Mask padding out of the loss. When the pad id is also the EOS id, the\n",
    "        # first occurrence in a row is the real end-of-sequence label and must\n",
    "        # stay; only the later occurrences are padding.\n",
    "        pad_mask = target_batch == tokenizer.pad_token_id\n",
    "        if tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
    "            pad_mask = pad_mask & (pad_mask.cumsum(dim=1) > 1)\n",
    "        labels = target_batch.masked_fill(pad_mask, IGNORE_INDEX)\n",
    "        logits = model(input_batch).logits  # forward pass\n",
    "        # Flatten using the logits' own vocabulary dimension instead of len(vocab).\n",
    "        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
    "        loss.backward()  # back-propagate\n",
    "        optimizer.step()  # update the parameters\n",
    "    if (epoch + 1) % LOG_EVERY == 0:\n",
    "        # Reports the loss of the last batch of the epoch.\n",
    "        print(f'Epoch: {epoch + 1:04d}, cost = {loss.item():.6f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试 1:\n",
      "User: what is the weather like today?\n",
      "AI:  application model application name for application application application application application application application model application application model application application application model application application model application application application model application application model application application application application application application application application application application application application application application application application program application application application\n",
      "\n",
      "测试 2:\n",
      "User: hi, how are you?\n",
      "AI:  how you have always be doing how you, you, you know me how you, you know how you are about doing how you, if need know how you need help if need know how you want help if need need need help if need need help\n",
      "\n",
      "测试 3:\n",
      "User: can you recommend a good book?\n",
      "AI:  har harper har har har har har har har har har har har har har har har har har har har har har har har har har har har har is har har har har har har har har har har har har har har har har har har\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def generate_text_beam_search(model, input_str, max_len=50, beam_width=5):\n",
    "    \"\"\"Generate a continuation of ``input_str`` with beam search.\n",
    "\n",
    "    Args:\n",
    "        model: Language model whose call result exposes ``.logits``.\n",
    "        input_str: Prompt text; encoded with the notebook-level ``tokenizer``.\n",
    "        max_len: Maximum number of tokens to generate.\n",
    "        beam_width: Number of hypotheses kept after every step (at least 1).\n",
    "\n",
    "    Returns:\n",
    "        The decoded continuation, without the prompt and without special tokens.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``beam_width`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if beam_width < 1:\n",
    "        raise ValueError(\"beam_width must be at least 1\")\n",
    "    model.eval()  # inference mode: disables dropout\n",
    "    input_tokens = tokenizer.encode(input_str, return_tensors=\"pt\").to(device)\n",
    "    prompt_len = input_tokens.size(1)\n",
    "    eos_id = tokenizer.eos_token_id\n",
    "    # A hypothesis is (tokens of shape (1, seq_len), cost), where cost is the\n",
    "    # negated sum of the generated tokens' log-probabilities: lower is better.\n",
    "    candidates = [(input_tokens, 0.0)]\n",
    "    # Hypotheses that emitted EOS. Kept across all steps so a finished sequence\n",
    "    # can still be chosen as the final answer.\n",
    "    finished = []\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_len):\n",
    "            new_candidates = []\n",
    "            for candidate, candidate_cost in candidates:\n",
    "                logits = model(candidate).logits[:, -1, :]\n",
    "                # Normalize to log-probabilities so costs are comparable across beams.\n",
    "                log_probs = torch.log_softmax(logits, dim=-1)\n",
    "                top_log_probs, top_tokens = torch.topk(log_probs, beam_width, dim=-1)\n",
    "                # Index the batch row rather than squeeze() so beam_width=1 still iterates.\n",
    "                for log_prob, next_token in zip(top_log_probs[0], top_tokens[0]):\n",
    "                    new_candidate = torch.cat((candidate, next_token.view(1, 1)), dim=-1)\n",
    "                    new_cost = candidate_cost - log_prob.item()\n",
    "                    if next_token.item() == eos_id:\n",
    "                        finished.append((new_candidate, new_cost))\n",
    "                    else:\n",
    "                        new_candidates.append((new_candidate, new_cost))\n",
    "            # Keep the beam_width cheapest unfinished hypotheses.\n",
    "            candidates = sorted(new_candidates, key=lambda x: x[1])[:beam_width]\n",
    "            if not candidates:\n",
    "                break  # every surviving hypothesis has ended with EOS\n",
    "\n",
    "    def _per_token_cost(hypothesis):\n",
    "        # Average over generated tokens so short and long hypotheses are comparable.\n",
    "        tokens, cost = hypothesis\n",
    "        return cost / max(tokens.size(1) - prompt_len, 1)\n",
    "\n",
    "    best_candidate, _ = min(finished + candidates, key=_per_token_cost)\n",
    "    # Drop the prompt tokens and decode only what was generated.\n",
    "    return tokenizer.decode(best_candidate[0, prompt_len:], skip_special_tokens=True)\n",
    "\n",
    "def run_generation_tests(prompts):\n",
    "    \"\"\"Generate a reply for every prompt and print the exchange, numbered from 1.\"\"\"\n",
    "    for i, input_str in enumerate(prompts, start=1):\n",
    "        generated_text = generate_text_beam_search(model, input_str)\n",
    "        print(f\"测试 {i}:\")\n",
    "        print(f\"User: {input_str}\")\n",
    "        print(f\"AI: {generated_text}\")\n",
    "        print()\n",
    "\n",
    "test_inputs = [\n",
    "    \"what is the weather like today?\",\n",
    "    \"hi, how are you?\",\n",
    "    \"can you recommend a good book?\"\n",
    "]\n",
    "run_generation_tests(test_inputs)\n",
    "\n",
    "# Same prompts with the greeting written as \"hi ,\" (space before the comma);\n",
    "# presumably to match how the training text is tokenized -- TODO confirm.\n",
    "test_inputs = [\n",
    "    \"what is the weather like today?\",\n",
    "    \"hi , how are you?\",\n",
    "    \"can you recommend a good book?\"\n",
    "]\n",
    "run_generation_tests(test_inputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
